import os

import numpy as np
import matplotlib.pyplot as plt

from utils import get_comm_mean_std, merge_comm

# Load the saved communication-bit arrays for each algorithm and reduce each
# one to a (mean, std) pair via get_comm_mean_std.
def _load_comm_bits(path):
    """Load a ``.npy`` array from *path*, expanding a leading ``~``.

    ``np.load`` opens the path exactly as given, so a literal ``~/...``
    would otherwise be looked up relative to the current directory.

    Args:
        path: Path to a ``.npy`` file; may start with ``~``.

    Returns:
        Whatever ``np.load`` returns for that file.
    """
    return np.load(os.path.expanduser(path))


feducb_comm_bits = _load_comm_bits('~/plt/feducb/comm_bits_list_mu10_0.npy')
print(f'feducb_comm_bits.shape:{feducb_comm_bits.shape}')
feducb_comm_bits_mean, feducb_comm_bits_std = get_comm_mean_std(feducb_comm_bits)

des_comm_bits_10 = _load_comm_bits('~/plt/des/comm_bits_lists_new_1e6_2.npy')
des_comm_bits_mean, des_comm_bits_std = get_comm_mean_std(des_comm_bits_10)

gossip_comm_bits = _load_comm_bits('~/plt/gossip/comm_bits_list.npy')
print(f'gossip_comm_bits.shape:{gossip_comm_bits.shape}')
gossip_comm_bits_mean, gossip_comm_bits_std = get_comm_mean_std(gossip_comm_bits)

ducb_comm_bits = _load_comm_bits('~/plt/ducb/comm_bits_list_50.npy')
print(f'ducb_comm_bits.shape:{ducb_comm_bits.shape}')
ducb_comm_bits_mean, ducb_comm_bits_std = get_comm_mean_std(ducb_comm_bits)

tomf_comm_bits = _load_comm_bits('~/plt/tomf/comm_bits_list_hetero_1e6_50.npy')
tomf_comm_bits_mean, tomf_comm_bits_std = get_comm_mean_std(tomf_comm_bits)

plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.size'] = 20
plt.rcParams['axes.labelweight'] = 'bold'
plt.figure(figsize=(10, 7))

# Style shared by every curve: hollow markers placed every MARKEVERY points,
# plus a faint shaded band spanning mean - std to mean + std.
MARKEVERY = 100000
BAND_ALPHA = 0.08

# One entry per curve, drawn in this order (later curves paint on top):
# (mean, std, label, color, linestyle, marker, markersize)
comm_bits_series = [
    (feducb_comm_bits_mean, feducb_comm_bits_std, 'FedUCB', 'blue', '--', '^', 8),
    (tomf_comm_bits_mean, tomf_comm_bits_std, 'TOMF', 'red', ':', '*', 12),
    (gossip_comm_bits_mean, gossip_comm_bits_std, 'Gossip', 'green', '-', 'o', 8),
    (des_comm_bits_mean, des_comm_bits_std, 'DES', 'purple', '-.', 'D', 8),
    (ducb_comm_bits_mean, ducb_comm_bits_std, 'Distributed-UCB', 'brown', '--', 'x', 10),
]

for mean, std, label, color, linestyle, marker, markersize in comm_bits_series:
    plt.plot(mean, linestyle=linestyle, marker=marker, color=color, label=label,
             markersize=markersize, markerfacecolor='white',
             markeredgewidth=2, markevery=MARKEVERY)
    plt.fill_between(range(len(mean)), mean - std, mean + std,
                     color=color, alpha=BAND_ALPHA)

plt.xlabel('Time Slots')
plt.ylabel('Communication Bits')
plt.yscale('log')
# Custom ticks: one per decade from 1e2 to 1e8, labelled as "1eN".
tick_decades = range(2, 9)
plt.yticks([10.0 ** k for k in tick_decades], [f'1e{k}' for k in tick_decades])

# Every curve carries a label, but the legend is currently not drawn;
# set this to True to add it to the figure.
SHOW_LEGEND = False
if SHOW_LEGEND:
    plt.legend()

plt.grid(False)
plt.tight_layout()
# Matplotlib opens the filename as given, so expand "~" explicitly.
plt.savefig(os.path.expanduser('~/plt/Comm_Bits.png'), dpi=300)